{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Secure sorting networks explained\n",
    "\n",
    "In this notebook, we develop some MPC protocols for securely sorting lists of secret-shared numbers. Concretely, we will show how to define functions sorting lists of secure MPyC integers into ascending order. The values represented by the secure integers and their relative order should remain completely secret.\n",
    "\n",
    "The explanation below assumes some basic familiarity with the MPyC framework for secure computation. Our main goal is to show how existing Python code for (oblivious) sorting can be used to implement a secure MPC sorting protocol using the `mpyc` package. The modifications to the existing code are very limited.\n",
    "\n",
    "## Sorting networks\n",
    "\n",
    "[Sorting networks](https://en.wikipedia.org/wiki/Sorting_network) are a classical type of comparison-based sorting algorithm. The basic operation (or gate) in a sorting network is the *compare&swap* operation, which puts any two list elements $x[i]$ and $x[j]$, $i<j$, in ascending order. That is, elements $x[i]$ and $x[j]$ are swapped only if $x[i]>x[j]$; otherwise the compare&swap operation leaves the list unchanged.\n",
    "\n",
    "A sorting network specifies the exact sequence of compare&swap operations to be applied to a list of a given length $n$. The sequence depends only on $n$, not on the values in the list. Even when the input list is already in ascending order, the sorting network performs exactly the same compare&swap operations as when the input list is in descending order.\n",
    "\n",
    "For example, to sort a list of three numbers, one needs to perform three compare&swap operations with indices $(i,j)$ equal to $(0,1)$, then $(1,2)$, and finally once more $(0,1)$.\n",
    "\n",
    "Below, we will use odd-even merge sort and bitonic sort, which are two well-known practical sorting networks. \n",
    "\n",
    "## MPyC setup\n",
    "\n",
    "A simple MPyC setup using 32-bit (default) secure MPyC integers suffices for the purpose of this demonstration.\n",
    "\n",
    "At this point we also import the Python `traceback` module for later use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-10-23 09:40:25,563 Start MPyC runtime v0.4.2\n"
     ]
    }
   ],
   "source": [
    "from mpyc.runtime import mpc    # load MPyC\n",
    "secint = mpc.SecInt()           # 32-bit secure MPyC integers\n",
    "mpc.run(mpc.start())            # required only when run with multiple parties\n",
    "import traceback                # to show some suppressed error messages"
   ]
  },
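  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to larger networks, the three-element example from the section on sorting networks can be checked in plain Python. The following is a minimal sketch without any MPyC types; the names `network3` and `apply_network` are introduced here for illustration only. It applies the fixed sequence $(0,1)$, $(1,2)$, $(0,1)$ to every ordering of three distinct numbers and checks that the result is always sorted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "# Illustrative names (not part of MPyC): the fixed compare&swap sequence for n = 3.\n",
    "network3 = [(0, 1), (1, 2), (0, 1)]\n",
    "\n",
    "def apply_network(network, values):\n",
    "    \"\"\"Apply each compare&swap gate (i, j) of the network to a copy of values.\"\"\"\n",
    "    y = list(values)\n",
    "    for i, j in network:\n",
    "        if y[i] > y[j]:  # plain (non-secret) comparison\n",
    "            y[i], y[j] = y[j], y[i]\n",
    "    return y\n",
    "\n",
    "# The same three gates sort all 3! = 6 input orderings.\n",
    "assert all(apply_network(network3, p) == [1, 2, 3] for p in permutations([1, 2, 3]))"
   ]
  },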
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Odd-even merge sort\n",
    "\n",
    "Odd-even merge sort is an elegant, but somewhat intricate, sorting network. The details are nicely explained in the Wikipedia article [Batcher's Odd-Even Mergesort](https://en.wikipedia.org/wiki/Batcher_odd–even_mergesort).\n",
    "\n",
    "For our purposes, however, there is no need to understand exactly how this particular sorting network works. All we need to do is take the following [example Python code](https://en.wikipedia.org/wiki/Batcher_odd–even_mergesort#Example_code) from this Wikipedia article."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def oddeven_merge(lo, hi, r):\n",
    "    step = r * 2\n",
    "    if step < hi - lo:\n",
    "        yield from oddeven_merge(lo, hi, step)\n",
    "        yield from oddeven_merge(lo + r, hi, step)\n",
    "        yield from [(i, i + r) for i in range(lo + r, hi - r, step)]\n",
    "    else:\n",
    "        yield (lo, lo + r)\n",
    "\n",
    "def oddeven_merge_sort_range(lo, hi):\n",
    "    \"\"\" sort the part of x with indices between lo and hi.\n",
    "\n",
    "    Note: endpoints (lo and hi) are included.\n",
    "    \"\"\"\n",
    "    if (hi - lo) >= 1:\n",
    "        # if there is more than one element, split the input\n",
    "        # down the middle and first sort the first and second\n",
    "        # half, followed by merging them.\n",
    "        mid = lo + ((hi - lo) // 2)\n",
    "        yield from oddeven_merge_sort_range(lo, mid)\n",
    "        yield from oddeven_merge_sort_range(mid + 1, hi)\n",
    "        yield from oddeven_merge(lo, hi, 1)\n",
    "\n",
    "def oddeven_merge_sort(length):\n",
    "    \"\"\" \"length\" is the length of the list to be sorted.\n",
    "    Yields pairs of indices (starting with 0) to compare and swap. \"\"\"\n",
    "    yield from oddeven_merge_sort_range(0, length - 1)\n",
    "\n",
    "def compare_and_swap(x, a, b):\n",
    "    if x[a] > x[b]:\n",
    "        x[a], x[b] = x[b], x[a]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run the code on a simple example. Note that this code assumes that the length of the input list is an integral power of two."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 2, 3, 4, 5, 6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "x = [2, 4, 3, 5, 6, 1, 7, 8]\n",
    "for i in oddeven_merge_sort(len(x)): compare_and_swap(x, *i)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We try to run this code on a list of secure MPyC integers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"<ipython-input-4-a875f9c118a5>\", line 3, in <module>\n",
      "    for i in oddeven_merge_sort(len(x)): compare_and_swap(x, *i)\n",
      "  File \"<ipython-input-2-a99b48746a99>\", line 30, in compare_and_swap\n",
      "    if x[a] > x[b]:\n",
      "  File \"c:\\users\\berry\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\mpyc-0.4.2-py3.6.egg\\mpyc\\sectypes.py\", line 36, in __bool__\n",
      "    raise TypeError('cannot use secure type in Boolean expressions')\n",
      "TypeError: cannot use secure type in Boolean expressions\n"
     ]
    }
   ],
   "source": [
    "x = list(map(secint, [2, 4, 3, 5, 6, 1, 7, 8]))\n",
    "try:\n",
    "    for i in oddeven_merge_sort(len(x)): compare_and_swap(x, *i)\n",
    "except:\n",
    "    traceback.print_exc()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unsurprisingly, this does not work. We get an error because we cannot use a `secint` directly in the condition of an `if` statement. And, even if we could, we should not do so, as the particular branch of the `if` statement followed reveals information about the input!\n",
    "\n",
    "Therefore, the function `compare_and_swap` is modified (i) to hide whether elements of $x$ are swapped and (ii) to keep the values of the elements of $x$ hidden, even when these are swapped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_and_swap(x, a, b):\n",
    "    c = x[a] > x[b]                  # secure comparison, secint c represents a secret-shared bit\n",
    "    d = c * (x[b] - x[a])            # d = x[b] - x[a] if c=1, and d = 0 if c=0\n",
    "    x[a], x[b] = x[a] + d, x[b] - d  # secure swap: x[a], x[b] swapped if and only if c=1"
   ]
  },
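  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this arithmetic performs a conditional swap, the same three steps can be traced on ordinary Python integers. The following is only a sketch: the helper name `arithmetic_swap` is hypothetical, and the comparison bit is computed in the clear here, whereas in the secure version it stays secret-shared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def arithmetic_swap(xa, xb):\n",
    "    \"\"\"Plain-integer analogue of the secure swap (illustrative helper).\"\"\"\n",
    "    c = int(xa > xb)        # comparison bit: 1 if out of order, else 0\n",
    "    d = c * (xb - xa)       # d = xb - xa if c == 1, else d = 0\n",
    "    return xa + d, xb - d   # (xb, xa) if c == 1, else (xa, xb)\n",
    "\n",
    "assert arithmetic_swap(5, 3) == (3, 5)  # out of order: swapped\n",
    "assert arithmetic_swap(3, 5) == (3, 5)  # already in order: unchanged\n",
    "assert arithmetic_swap(4, 4) == (4, 4)  # equal values: unchanged"
   ]
  },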
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the code can be used to sort a list of secure MPyC integers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 2, 3, 4, 5, 6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "x = list(map(secint, [2, 4, 3, 5, 6, 1, 7, 8]))\n",
    "for i in oddeven_merge_sort(len(x)): compare_and_swap(x, *i)\n",
    "print(mpc.run(mpc.output(x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bitonic sort\n",
    "\n",
    "For our next example, we consult the Wikipedia article [Bitonic Sorter](https://en.wikipedia.org/wiki/Bitonic_sorter).\n",
    "\n",
    "We apply the same approach, grabbing the [example Python code](https://en.wikipedia.org/wiki/Bitonic_sorter#Example_code) from the Wikipedia article, which is also designed to work for input lists whose length is an integral power of two."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bitonic_sort(up, x):\n",
    "    if len(x) <= 1:\n",
    "        return x\n",
    "    else:\n",
    "        first = bitonic_sort(True, x[:len(x) // 2])\n",
    "        second = bitonic_sort(False, x[len(x) // 2:])\n",
    "        return bitonic_merge(up, first + second)\n",
    "\n",
    "def bitonic_merge(up, x):\n",
    "    # assume input x is bitonic, and sorted list is returned\n",
    "    if len(x) == 1:\n",
    "        return x\n",
    "    else:\n",
    "        bitonic_compare(up, x)\n",
    "        first = bitonic_merge(up, x[:len(x) // 2])\n",
    "        second = bitonic_merge(up, x[len(x) // 2:])\n",
    "        return first + second\n",
    "\n",
    "def bitonic_compare(up, x):\n",
    "    dist = len(x) // 2\n",
    "    for i in range(dist):\n",
    "        if (x[i] > x[i + dist]) == up:\n",
    "            x[i], x[i + dist] = x[i + dist], x[i]  # swap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run the code on the same example. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 2, 3, 4, 5, 6, 7, 8]\n",
      "[8, 7, 6, 5, 4, 3, 2, 1]\n"
     ]
    }
   ],
   "source": [
    "print(bitonic_sort(True, [2, 4, 3, 5, 6, 1, 7, 8]))\n",
    "print(bitonic_sort(False, [2, 4, 3, 5, 6, 1, 7, 8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running the code on a list of secure MPyC integers gives the same error as above. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"<ipython-input-9-4b604066647c>\", line 3, in <module>\n",
      "    bitonic_sort(True, x)\n",
      "  File \"<ipython-input-7-8a5c3e13ac49>\", line 5, in bitonic_sort\n",
      "    first = bitonic_sort(True, x[:len(x) // 2])\n",
      "  File \"<ipython-input-7-8a5c3e13ac49>\", line 5, in bitonic_sort\n",
      "    first = bitonic_sort(True, x[:len(x) // 2])\n",
      "  File \"<ipython-input-7-8a5c3e13ac49>\", line 7, in bitonic_sort\n",
      "    return bitonic_merge(up, first + second)\n",
      "  File \"<ipython-input-7-8a5c3e13ac49>\", line 14, in bitonic_merge\n",
      "    bitonic_compare(up, x)\n",
      "  File \"<ipython-input-7-8a5c3e13ac49>\", line 22, in bitonic_compare\n",
      "    if (x[i] > x[i + dist]) == up:\n",
      "  File \"c:\\users\\berry\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\mpyc-0.4.2-py3.6.egg\\mpyc\\sectypes.py\", line 36, in __bool__\n",
      "    raise TypeError('cannot use secure type in Boolean expressions')\n",
      "TypeError: cannot use secure type in Boolean expressions\n"
     ]
    }
   ],
   "source": [
    "x = list(map(secint, [2, 4, 3, 5, 6, 1, 7, 8]))\n",
    "try:\n",
    "    bitonic_sort(True, x)\n",
    "except:\n",
    "    traceback.print_exc()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time we modify the function `bitonic_compare`, again to hide what happens to the elements of $x$ being compared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bitonic_compare(up, x):\n",
    "    dist = len(x) // 2\n",
    "    up = secint(up)                                    # convert public Boolean up into `secint` bit \n",
    "    for i in range(dist):\n",
    "        b = (x[i] > x[i + dist]) ^ ~up                 # secure xor of comparison bit and negated up\n",
    "        d = b * (x[i + dist] - x[i])                   # d = 0 or d = x[i + dist] - x[i]\n",
    "        x[i], x[i + dist] = x[i] + d, x[i + dist] - d  # secure swap"
   ]
  },
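  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expression `(x[i] > x[i + dist]) ^ ~up` replaces the public test `(x[i] > x[i + dist]) == up`. As a quick sanity check on plain bits, and assuming that `~up` acts as $1 - up$ on a secure bit (an assumption about the secure type, not verified here), the truth table below confirms that the xor with the negated direction bit equals the equality test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-bit check: c ^ (1 - up) is 1 exactly when c == up.\n",
    "# Here 1 - up stands in for ~up on a single bit (assumption stated above).\n",
    "for c in (0, 1):          # c: comparison bit\n",
    "    for up in (0, 1):     # up: sort direction bit\n",
    "        b = c ^ (1 - up)\n",
    "        assert b == int(c == up)"
   ]
  },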
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the code can again be used to sort a list of secure MPyC integers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 2, 3, 4, 5, 6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "print(mpc.run(mpc.output(bitonic_sort(True, x))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-10-23 09:40:26,468 Stop MPyC runtime -- elapsed time: 0:00:00.900694\n"
     ]
    }
   ],
   "source": [
    "mpc.run(mpc.shutdown())   # required only when run with multiple parties"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Python script [sort.py](sort.py) shows how to do secure bitonic sort for lists of arbitrary length, adapted from this general [bitonic sorter](http://www.iti.fh-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
